Add Device context manager for temporary device switching#1597
Add Device context manager for temporary device switching#1597Andy-Jost wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
Closes NVIDIA#1586. Adds __enter__/__exit__ to Device so it can be used as a context manager that saves the current CUDA context on entry and restores it on exit. Uses cuCtxGetCurrent/cuCtxSetCurrent (not push/pop) for interoperability with the runtime API. Saved contexts are stored on a per-thread stack (_tls._ctx_stack) so nested and reentrant usage works correctly. Also adds teardown to mempool_device_x2/x3 fixtures to clean up residual contexts between tests. Co-authored-by: Cursor <cursoragent@cursor.com>
|
/ok to test f02b730 |
|
cpcloud
left a comment
There was a problem hiding this comment.
Solid implementation — the design choices (stateless restoration via cuCtxGetCurrent/cuCtxSetCurrent, per-thread stack, peek-then-pop in __exit__) are all correct and well-reasoned. Tests are thorough. Two items to address before merge:
__exit__can mask the caller's exception — ifcuCtxSetCurrentraises during unwinding from a user exception, the CUDA error replaces the original (it's still in__context__, but most users won't look there). See inline comment for a suggested fix.- Missing docs / release notes — this adds a new public API entry point on
Devicebut the PR doesn't updateinteroperability.rst(the "Current device/context" section should mention the context manager alongsideset_current()),getting-started.rst, or0.7.x-notes.rst. Even for a draft, having these roughed in makes the feature discoverable.
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx)) | ||
| _tls._ctx_stack.pop() | ||
| return False |
There was a problem hiding this comment.
If cuCtxSetCurrent raises here while an exception is already propagating from the with block, the CUDA error replaces the user's original exception. The original ends up in e.__context__ but is easy to miss.
Consider guarding:
def __exit__(self, exc_type, exc_val, exc_tb):
cdef uintptr_t prev_ctx_ptr = _tls._ctx_stack.pop()
cdef cydriver.CUcontext prev_ctx = <cydriver.CUcontext><void*>prev_ctx_ptr
try:
with nogil:
HANDLE_RETURN(cydriver.cuCtxSetCurrent(prev_ctx))
except Exception:
if exc_type is None:
raise # no active exception, surface the CUDA error
# else: swallow the restore failure to preserve the original exception;
# the stack entry is already popped so the next __exit__ won't retry.
return FalseThis also simplifies the peek-then-pop dance — just pop eagerly, since a failed cuCtxSetCurrent with a context obtained from cuCtxGetCurrent moments earlier is essentially unrecoverable anyway.
(_graphics.pyx has the same pattern, so this is a pre-existing issue there too, but worth getting right here.)
| cdef cydriver.CUcontext prev_ctx | ||
| with nogil: | ||
| HANDLE_RETURN(cydriver.cuCtxGetCurrent(&prev_ctx)) | ||
| if not hasattr(_tls, '_ctx_stack'): |
There was a problem hiding this comment.
Nit: the try/except AttributeError pattern (EAFP) is marginally faster after the first call per thread, and is the more common pattern elsewhere in _device.pyx (see Device_ensure_tls_devices):
try:
_tls._ctx_stack
except AttributeError:
_tls._ctx_stack = []Not blocking — just a consistency suggestion.
|
|
||
| # ============================================================================ | ||
| # Device Context Manager Tests | ||
| # ============================================================================ |
There was a problem hiding this comment.
Nit: consider putting _get_current_context() in conftest.py — it's generally useful for any context-related test, and other test files may want it as multi-device / context-switching tests grow.
leofang
left a comment
There was a problem hiding this comment.
Blocking accidental merge of this PR before a design review/survey happens, as discussed offline and described in the linked issue.
Summary
Closes #1586. Adds
__enter__/__exit__toDeviceso it can be used as a context manager that temporarily activates a device and restores the previous CUDA context on exit.Changes
cuda/core/_device.pyx: Added__enter__and__exit__methods toDevice. On enter, queries the current context viacuCtxGetCurrentand saves it on a per-thread stack (_tls._ctx_stack), then callsset_current(). On exit, restores the saved context viacuCtxSetCurrent. Uses peek-then-pop ordering so the stack is not corrupted ifcuCtxSetCurrentraises.tests/test_device.py: Added 12 tests covering basic usage, context restoration, exception safety, same-device nesting, deep nesting, multi-GPU nesting,set_current()inside awithblock, device usability after exit, device initialization, and thread safety (3 threads on 3 GPUs).tests/conftest.py: Added teardown tomempool_device_x2andmempool_device_x3fixtures to clean up residual contexts between tests.Design
__enter__queries the actual CUDA driver state rather than maintaining a Python-side device cache. This ensures correct interoperability with other libraries (PyTorch, CuPy) that usecudaSetDevice/cuCtxSetCurrent.Devicesingleton), so nested and reentrant usage works correctly.cuCtxGetCurrent/cuCtxSetCurrent: Consistent withset_current()and the runtime API model. Does not usecuCtxPushCurrent/cuCtxPopCurrent.Test Coverage
All tests pass locally on single-GPU (L40) and multi-GPU (3x RTX PRO 6000 Blackwell) machines. Stress-tested with 20 randomized iterations via
pytest-repeat+pytest-randomlywith no ordering issues.Made with Cursor